/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.reactive.socket.client;

import java.net.URI;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.websocket.WebsocketInbound;

import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.adapter.ReactorNettyWebSocketSession;

/**
 * {@link WebSocketClient} implementation for use with Reactor Netty.
 *
 * @author Rossen Stoyanchev
 * @since 5.0
 */
public class ReactorNettyWebSocketClient implements WebSocketClient {

    private static final Log logger = LogFactory.getLog(ReactorNettyWebSocketClient.class);


    private final HttpClient httpClient;


    /**
     * Default constructor.
     */
    public ReactorNettyWebSocketClient() {
        this(HttpClient.create());
    }

    /**
     * Constructor that accepts an existing {@link HttpClient} builder.
     *
     * @since 5.1
     */
    public ReactorNettyWebSocketClient(HttpClient httpClient) {
        Assert.notNull(httpClient, "HttpClient is required");
        this.httpClient = httpClient;
    }


    /**
     * Return the configured {@link HttpClient}.
     */
    public HttpClient getHttpClient() {
        return this.httpClient;
    }


    @Override
    public Mono<Void> execute(URI url, WebSocketHandler handler) {
        return execute(url, new HttpHeaders(), handler);
    }

    @Override
    public Mono<Void> execute(URI url, HttpHeaders requestHeaders, WebSocketHandler handler) {
        return getHttpClient()
                .headers(nettyHeaders -> setNettyHeaders(requestHeaders, nettyHeaders))
                .websocket(StringUtils.collectionToCommaDelimitedString(handler.getSubProtocols()))
                .uri(url.toString())
                .handle((inbound, outbound) -> {
                    HttpHeaders responseHeaders = toHttpHeaders(inbound);
                    String protocol = responseHeaders.getFirst("Sec-WebSocket-Protocol");
                    HandshakeInfo info = new HandshakeInfo(url, responseHeaders, Mono.empty(), protocol);
                    NettyDataBufferFactory factory = new NettyDataBufferFactory(outbound.alloc());
                    WebSocketSession session = new ReactorNettyWebSocketSession(inbound, outbound, info, factory);
                    if (logger.isDebugEnabled()) {
                        logger.debug("Started session '" + session.getId() + "' for " + url);
                    }
                    return handler.handle(session);
                })
                .doOnRequest(n -> {
                    if (logger.isDebugEnabled()) {
                        logger.debug("Connecting to " + url);
                    }
                })
                .next();
    }

    private void setNettyHeaders(HttpHeaders httpHeaders, io.netty.handler.codec.http.HttpHeaders nettyHeaders) {
        httpHeaders.forEach(nettyHeaders::set);
    }

    private HttpHeaders toHttpHeaders(WebsocketInbound inbound) {
        HttpHeaders headers = new HttpHeaders();
        io.netty.handler.codec.http.HttpHeaders nettyHeaders = inbound.headers();
        nettyHeaders.forEach(entry -> {
            String name = entry.getKey();
            headers.put(name, nettyHeaders.getAll(name));
        });
        return headers;
    }

}
